Attention Mechanism Comparison¶
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
import pandas as pd
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
from transformers import (
AutoTokenizer, AutoModel,
pipeline, AutoConfig
)
from datasets import load_dataset
from bertviz import model_view, head_view, neuron_view
from bertviz.transformers_neuron_view import BertModel
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cpu
import html
import re
def load_sentences(n_samples=100, min_words=10, max_words=20):
    """Collect a balanced sample of cleaned AG News sentences.

    Draws ``n_samples // 4`` sentences from each of the four AG News
    categories, keeping only texts whose word count lies in
    ``[min_words, max_words]`` and that are longer than 20 characters.

    Returns a (sentences, domains) pair of parallel lists.
    """
    dataset = load_dataset("ag_news", split="train")
    label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    sentences, domains = [], []
    per_category = n_samples // 4
    for label in range(4):
        # Deterministic shuffle so the sample is reproducible across runs.
        candidates = dataset.filter(lambda x: x['label'] == label).shuffle(seed=42)
        taken = 0
        for item in candidates:
            if taken >= per_category:
                break
            # Decode HTML entities, strip tags, collapse whitespace.
            cleaned = html.unescape(item['text'])
            cleaned = re.sub(r'<[^>]+>', '', cleaned)
            cleaned = re.sub(r'\s+', ' ', cleaned).strip()
            n_words = len(cleaned.split())
            if min_words <= n_words <= max_words and len(cleaned) > 20:
                sentences.append(cleaned)
                domains.append(label_names[label])
                taken += 1
    print(f"\n Loaded {len(sentences)} sentences from AG News:")
    for label_name in label_names.values():
        count = domains.count(label_name)
        print(f" {label_name}: {count} sentences")
    print(f"\n Sample sentences:")
    for i in range(min(3, len(sentences))):
        print(f" [{domains[i]}] {sentences[i][:80]}...")
    return sentences, domains
# Build the balanced evaluation corpus (25 sentences per AG News category).
TEST_SENTENCES, SENTENCE_DOMAINS = load_sentences(n_samples=100, min_words=10, max_words=20)
Loaded 100 sentences from AG News: World: 25 sentences Sports: 25 sentences Business: 25 sentences Sci/Tech: 25 sentences Sample sentences: [World] Somalis vie to be new president Twenty-eight candidates are approved to contest ... [World] Agency pleads for hostage release Care International appeals on Arabic televisio... [World] Clinton recovering after heart op Former US President Bill Clinton's heart bypas...
# Registry of the transformer checkpoints compared in this notebook.
# "name" is the HuggingFace model id; "color" keeps each model's hue
# consistent across all plots below.
MODELS = {
    "BERT": {
        "name": "bert-base-uncased",
        "description": "Original BERT",
        "color": "#FF6B6B"
    },
    "DistilBERT": {
        "name": "distilbert-base-uncased",
        "description": "Compressed BERT",
        "color": "#4ECDC4"
    },
    "RoBERTa": {
        "name": "roberta-base",
        "description": "Optimized BERT",
        "color": "#45B7D1"
    }
}
def load_model_and_tokenizer(model_name):
    """
    Load a pretrained model, its tokenizer and its configuration.

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g. "bert-base-uncased").

    Returns
    -------
    tuple
        (model, tokenizer, config). The model is moved to the global
        ``device`` and configured to return attention weights, which the
        visualizations below rely on.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(
        model_name,
        output_attentions=True,  # required so outputs carry .attentions
        return_dict=True).to(device)
    config = AutoConfig.from_pretrained(model_name)
    # Fixed: the message mixed French ("Modèle") into an otherwise English file.
    print(f"Model {model_name} loaded")
    return model, tokenizer, config
# Instantiate each model once and keep it alongside its display metadata.
models_data = {}
for key, info in MODELS.items():
    loaded_model, loaded_tokenizer, loaded_config = load_model_and_tokenizer(info["name"])
    models_data[key] = {
        "model": loaded_model,
        "tokenizer": loaded_tokenizer,
        "config": loaded_config,
        "description": info["description"],
        "color": info["color"],
    }
print("✅ All models loaded")
Modèle bert-base-uncased loaded Modèle distilbert-base-uncased loaded
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Modèle roberta-base loaded ✅ All models loaded
def analyze_model_architecture(models_data):
    """Build a comparison table of model architectures.

    For every entry in ``models_data`` the layer count, hidden size,
    attention-head count, vocabulary size and parameter counts are printed
    and collected into a pandas DataFrame (one row per model).
    """
    rows = []
    for model_name, data in models_data.items():
        config = data["config"]
        model = data["model"]
        # Parameter totals straight from the torch module.
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        rows.append({
            "Model": model_name,
            "Layers": config.num_hidden_layers,
            "Hidden Size": config.hidden_size,
            "Attention Heads": config.num_attention_heads,
            "Vocab Size": config.vocab_size,
            "Total Params": f"{total_params:,}",
            "Trainable Params": f"{trainable_params:,}",
        })
        print(f"\n{model_name}:")
        print(f" Layers: {config.num_hidden_layers}")
        print(f" Hidden Size: {config.hidden_size}")
        print(f" Attention Heads: {config.num_attention_heads}")
        print(f" Vocab Size: {config.vocab_size:,}")
        print(f" Parameters: {total_params:,}")
    return pd.DataFrame(rows)
# Compare the three architectures side by side.
architecture_df = analyze_model_architecture(models_data)
print("\n Comparison table:")
display(architecture_df)
BERT: Layers: 12 Hidden Size: 768 Attention Heads: 12 Vocab Size: 30,522 Parameters: 109,482,240 DistilBERT: Layers: 6 Hidden Size: 768 Attention Heads: 12 Vocab Size: 30,522 Parameters: 66,362,880 RoBERTa: Layers: 12 Hidden Size: 768 Attention Heads: 12 Vocab Size: 50,265 Parameters: 124,645,632 Comparison table:
| Model | Layers | Hidden Size | Attention Heads | Vocab Size | Total Params | Trainable Params | |
|---|---|---|---|---|---|---|---|
| 0 | BERT | 12 | 768 | 12 | 30522 | 109,482,240 | 109,482,240 |
| 1 | DistilBERT | 6 | 768 | 12 | 30522 | 66,362,880 | 66,362,880 |
| 2 | RoBERTa | 12 | 768 | 12 | 50265 | 124,645,632 | 124,645,632 |
def analyze_sentence_attention(sentence, model_name, models_data):
    """Run one sentence through a model and return its attention tensors.

    Parameters
    ----------
    sentence : str
        Input text to analyze.
    model_name : str
        Key into ``models_data``.
    models_data : dict
        Mapping of model name -> {"model": ..., "tokenizer": ..., ...}.

    Returns
    -------
    dict
        Tokens, the per-layer attention tensors, the encoded inputs and
        the layer/head counts.
    """
    print(f"Analyze attention for: '{sentence}'")
    print(f"Model: {model_name}")
    tokenizer = models_data[model_name]["tokenizer"]
    model = models_data[model_name]["model"]
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    print(f"Tokens ({len(tokens)}): {tokens[:10]}{'...' if len(tokens) > 10 else ''}")
    # Fixed: the encoded inputs were left on CPU while the model may live on
    # GPU. Move them to the model's own device before the forward pass
    # (same pattern as create_attention_visualizations).
    model_device = next(model.parameters()).device
    inputs = {k: v.to(model_device) for k, v in inputs.items()}
    # Prediction with attention
    with torch.no_grad():
        outputs = model(**inputs)
    attentions = outputs.attentions
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1]
    seq_length = attentions[0].shape[-1]
    print(f"Attention shape: {num_layers} layers, {num_heads} heads, {seq_length} tokens")
    return {
        "tokens": tokens,
        "attentions": attentions,
        "inputs": inputs,
        "num_layers": num_layers,
        "num_heads": num_heads
    }
# Extract attention for one sample sentence on every model.
test_sentence = TEST_SENTENCES[1]
print(f" Test on: '{test_sentence}'\n")
attention_results = {}
for model_name in models_data.keys():
    attention_results[model_name] = analyze_sentence_attention(
        test_sentence, model_name, models_data
    )
    print("-" * 50)
Test on: 'Agency pleads for hostage release Care International appeals on Arabic television for the release of its Iraq director, Margaret Hassan.' Analyze attention for: 'Agency pleads for hostage release Care International appeals on Arabic television for the release of its Iraq director, Margaret Hassan.' Model: BERT Tokens (25): ['[CLS]', 'agency', 'plead', '##s', 'for', 'hostage', 'release', 'care', 'international', 'appeals']... Attention shape: 12 layers, 12 heads, 25 tokens -------------------------------------------------- Analyze attention for: 'Agency pleads for hostage release Care International appeals on Arabic television for the release of its Iraq director, Margaret Hassan.' Model: DistilBERT Tokens (25): ['[CLS]', 'agency', 'plead', '##s', 'for', 'hostage', 'release', 'care', 'international', 'appeals']... Attention shape: 6 layers, 12 heads, 25 tokens -------------------------------------------------- Analyze attention for: 'Agency pleads for hostage release Care International appeals on Arabic television for the release of its Iraq director, Margaret Hassan.' Model: RoBERTa Tokens (26): ['<s>', 'A', 'gency', 'Ġple', 'ads', 'Ġfor', 'Ġhostage', 'Ġrelease', 'ĠCare', 'ĠInternational']... Attention shape: 12 layers, 12 heads, 26 tokens --------------------------------------------------
from IPython.display import display, HTML
def create_attention_visualizations(sentence, model_name, models_data):
    """Render BertViz model-level and head-level attention views for a sentence."""
    entry = models_data[model_name]
    tokenizer = entry["tokenizer"]
    model = entry["model"]
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=512)
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    # Keep the inputs on the same device as the model.
    encoded = {k: v.to(device) for k, v in encoded.items()}
    with torch.no_grad():
        attentions = model(**encoded).attentions
    print(f"Tokens: {len(tokens)} tokens")
    print(f"Attention: {len(attentions)} layers, {attentions[0].shape[1]} heads")
    print("Model View (all attention heads):")
    model_view(attentions, tokens)
    print("Head View (per head detail):")
    head_view(attentions, tokens)
    return {
        "tokens": tokens,
        "attentions": attentions,
        "num_layers": len(attentions),
        "num_heads": attentions[0].shape[1],
    }
# Produce the interactive visualizations for each model in turn.
visualization_results = {}
for current_model in models_data.keys():
    print(f"\n{'='*60}")
    print(f"VISUALIZING {current_model}")
    print(f"{'='*60}")
    visualization_results[current_model] = create_attention_visualizations(
        test_sentence, current_model, models_data
    )
============================================================ VISUALIZING BERT ============================================================ Tokens: 25 tokens Attention: 12 layers, 12 heads Model View (all attention heads):
Head View (per head detail):
============================================================ VISUALIZING DistilBERT ============================================================ Tokens: 25 tokens Attention: 6 layers, 12 heads Model View (all attention heads):
Head View (per head detail):
============================================================ VISUALIZING RoBERTa ============================================================ Tokens: 26 tokens Attention: 12 layers, 12 heads Model View (all attention heads):
Head View (per head detail):
Understanding Attention Patterns¶
What are Layers and Heads?¶
Layers are like steps in processing - each layer refines the understanding:
- Early layers (0-3): Focus on grammar and word relationships
- Middle layers (4-8): Build meaning and context
- Final layers (9-11): Aggregate information for the final output
Attention Heads decide which words are important. Think of it like each word "looking at" other words to understand context. With 12 heads per layer, the model examines different aspects simultaneously.
Reading the Visualizations¶
Model View: A grid showing all attention heads across all layers at once. Each small matrix shows the attention pattern for one head in one layer.
Head View: Shows which words pay attention to which, for a selected layer and head. Brighter/thicker lines = stronger attention.
Comparing the Models¶
BERT - Baseline with hierarchical patterns from syntax to semantics
DistilBERT - Compressed to 6 layers but maintains effectiveness with more focused patterns
RoBERTa - Optimized training leads to cleaner, more targeted attention patterns
def analyze_attention_patterns(sentence, models_data, display_output=True):
    """Compute summary attention statistics for one sentence on every model.

    All metrics are taken from the LAST attention layer, averaged over heads:
      - cls_attention_max / cls_attention_mean : attention paid by the first
        token ([CLS] for BERT-style models, <s> for RoBERTa)
      - self_attention_mean : diagonal (token-to-itself) attention
      - content_vs_function : mean attention to content words vs a small
        stop-word list of function words
      - attention_entropy   : entropy of the head-averaged attention matrix

    Returns a dict mapping model name -> metrics dict.
    """
    # Tokens that should not count as words at all. Covers both BERT-style
    # ([CLS]/[SEP]/[PAD]) and RoBERTa-style (<s>/</s>/<pad>) specials.
    special_tokens = {'[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>'}
    function_word_set = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'of', 'in', 'on', 'at'}
    if display_output:
        print(f"ANALYSE PATTERNS: '{sentence}'")
        print("=" * 60)
    patterns_analysis = {}
    for model_name, data in models_data.items():
        if display_output:
            print(f"\n{model_name}:")
        tokenizer = data["tokenizer"]
        model = data["model"]
        inputs = tokenizer(sentence, return_tensors="pt", return_offsets_mapping=False)
        tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        # Fixed: run the forward pass on the model's device (inputs were CPU-only).
        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        attentions = outputs.attentions
        last_attention = attentions[-1][0]  # [heads, seq_len, seq_len]
        mean_attention = last_attention.mean(dim=0)  # head-averaged [seq, seq]
        # 1. Attention paid by the first (aggregate) token
        cls_attention = last_attention[:, 0, :].mean(dim=0)
        # 2. Self-attention (attention of a token to itself)
        self_attention = torch.diagonal(mean_attention)
        # 3. Attention to content words vs function words.
        # Fixed: RoBERTa's word-initial tokens carry a 'Ġ' (or '▁') prefix, so
        # e.g. 'Ġthe' never matched the stop-word list and RoBERTa's
        # content/function ratio was wildly inflated. Normalize first.
        content_words = []
        function_words = []
        for i, token in enumerate(tokens):
            if token in special_tokens or token.startswith('##'):
                continue  # skip specials and BERT subword continuations
            normalized = token.lstrip('Ġ▁').lower()
            if normalized in function_word_set:
                function_words.append(i)
            else:
                content_words.append(i)
        avg_attention_to_content = mean_attention[:, content_words].mean() if content_words else 0
        avg_attention_to_function = mean_attention[:, function_words].mean() if function_words else 0
        patterns = {
            "cls_attention_max": float(cls_attention.max()),
            "cls_attention_mean": float(cls_attention.mean()),
            "self_attention_mean": float(self_attention.mean()),
            # 0.001 floor avoids division by zero when no function words occur
            "content_vs_function": float(avg_attention_to_content / max(avg_attention_to_function, 0.001)),
            # 1e-10 guards log(0); sum over the whole head-averaged matrix
            "attention_entropy": float(-torch.sum(mean_attention * torch.log(mean_attention + 1e-10)))
        }
        patterns_analysis[model_name] = patterns
        if display_output:
            print(f" CLS attention max: {patterns['cls_attention_max']:.3f}")
            print(f" Self-attention mean: {patterns['self_attention_mean']:.3f}")
            print(f" Content/Function ratio: {patterns['content_vs_function']:.2f}")
            print(f" Attention entropy: {patterns['attention_entropy']:.3f}")
    return patterns_analysis
# Run the pattern analysis on every test sentence, printing every 10th one.
all_patterns = {}
for i, sentence in enumerate(TEST_SENTENCES):
    # Fixed: this flag was named `display`, shadowing IPython.display.display
    # at module scope and breaking every later `display(...)` call.
    show_output = (i + 1) % 10 == 0
    if show_output:
        print(f"\n{'='*20} SENTENCE {i+1} {'='*20}")
    patterns = analyze_attention_patterns(sentence, models_data, display_output=show_output)
    all_patterns[f"Sentence_{i+1}"] = patterns
==================== SENTENCE 10 ==================== ANALYSE PATTERNS: 'Painkiller risk to gut revealed The risk of intestinal damage from common painkillers may be higher than thought, research suggests.' ============================================================ BERT: CLS attention max: 0.144 Self-attention mean: 0.109 Content/Function ratio: 2.25 Attention entropy: 63.005 DistilBERT: CLS attention max: 0.176 Self-attention mean: 0.082 Content/Function ratio: 2.47 Attention entropy: 62.340 RoBERTa: CLS attention max: 0.263 Self-attention mean: 0.123 Content/Function ratio: 38.46 Attention entropy: 50.159 ==================== SENTENCE 20 ==================== ANALYSE PATTERNS: 'Cambodia set to crown new king Cambodians prepare for the coronation of King Sihamoni, amid an array of official festivities.' ============================================================ BERT: CLS attention max: 0.106 Self-attention mean: 0.087 Content/Function ratio: 3.80 Attention entropy: 58.372 DistilBERT: CLS attention max: 0.141 Self-attention mean: 0.080 Content/Function ratio: 3.51 Attention entropy: 60.281 RoBERTa: CLS attention max: 0.266 Self-attention mean: 0.119 Content/Function ratio: 32.26 Attention entropy: 61.734 ==================== SENTENCE 30 ==================== ANALYSE PATTERNS: 'Quincy gets its revenge It took a year, but Quincy's volleyball team has bragging rights in the city again.' 
============================================================ BERT: CLS attention max: 0.116 Self-attention mean: 0.099 Content/Function ratio: 3.37 Attention entropy: 57.322 DistilBERT: CLS attention max: 0.161 Self-attention mean: 0.087 Content/Function ratio: 2.74 Attention entropy: 61.465 RoBERTa: CLS attention max: 0.278 Self-attention mean: 0.144 Content/Function ratio: 2.06 Attention entropy: 52.182 ==================== SENTENCE 40 ==================== ANALYSE PATTERNS: 'Tyson Completes Service Charges stemming from a 2003 altercation are dropped as Mike Tyson completes community service on Wednesday.' ============================================================ BERT: CLS attention max: 0.114 Self-attention mean: 0.107 Content/Function ratio: 2.47 Attention entropy: 41.588 DistilBERT: CLS attention max: 0.187 Self-attention mean: 0.096 Content/Function ratio: 1.83 Attention entropy: 48.059 RoBERTa: CLS attention max: 0.279 Self-attention mean: 0.149 Content/Function ratio: 40.00 Attention entropy: 50.392 ==================== SENTENCE 50 ==================== ANALYSE PATTERNS: 'Martinez Deal Finalized Martinez passes his physical, and the Mets finalize their \$53 million, four-year contract with the pitcher.' ============================================================ BERT: CLS attention max: 0.114 Self-attention mean: 0.092 Content/Function ratio: 4.21 Attention entropy: 59.085 DistilBERT: CLS attention max: 0.180 Self-attention mean: 0.078 Content/Function ratio: 2.12 Attention entropy: 68.812 RoBERTa: CLS attention max: 0.259 Self-attention mean: 0.150 Content/Function ratio: 32.26 Attention entropy: 61.319 ==================== SENTENCE 60 ==================== ANALYSE PATTERNS: 'Rising material costs hit Heinz Second quarter profits at ketchup maker Heinz are hit by higher material and transport costs.' 
============================================================ BERT: CLS attention max: 0.132 Self-attention mean: 0.120 Content/Function ratio: 1.54 Attention entropy: 49.509 DistilBERT: CLS attention max: 0.252 Self-attention mean: 0.087 Content/Function ratio: 1.71 Attention entropy: 48.379 RoBERTa: CLS attention max: 0.282 Self-attention mean: 0.126 Content/Function ratio: 37.04 Attention entropy: 51.436 ==================== SENTENCE 70 ==================== ANALYSE PATTERNS: 'ADV: \$150,000 Mortgage for Under \$690/Month Mortgage rates are at record lows. Save \$1000s on your mortgage payment. Free quotes.' ============================================================ BERT: CLS attention max: 0.090 Self-attention mean: 0.103 Content/Function ratio: 1.56 Attention entropy: 95.987 DistilBERT: CLS attention max: 0.144 Self-attention mean: 0.072 Content/Function ratio: 1.42 Attention entropy: 88.378 RoBERTa: CLS attention max: 0.213 Self-attention mean: 0.127 Content/Function ratio: 26.32 Attention entropy: 84.686 ==================== SENTENCE 80 ==================== ANALYSE PATTERNS: 'Reg readers name BSA antipiracy weasel Poll result The people have spoken' ============================================================ BERT: CLS attention max: 0.114 Self-attention mean: 0.150 Content/Function ratio: 0.96 Attention entropy: 28.237 DistilBERT: CLS attention max: 0.281 Self-attention mean: 0.109 Content/Function ratio: 0.79 Attention entropy: 28.323 RoBERTa: CLS attention max: 0.359 Self-attention mean: 0.147 Content/Function ratio: 55.56 Attention entropy: 32.112 ==================== SENTENCE 90 ==================== ANALYSE PATTERNS: 'Last Xmas order date for the Antipodes Cash'n'Carrion Get 'em in by Sunday' ============================================================ BERT: CLS attention max: 0.200 Self-attention mean: 0.127 Content/Function ratio: 0.91 Attention entropy: 43.208 DistilBERT: CLS attention max: 0.289 Self-attention mean: 0.104 Content/Function ratio: 
0.85 Attention entropy: 44.560 RoBERTa: CLS attention max: 0.358 Self-attention mean: 0.136 Content/Function ratio: 40.00 Attention entropy: 47.427 ==================== SENTENCE 100 ==================== ANALYSE PATTERNS: 'Will historic flight launch space tourism? Regardless, space competitions are poised to become big business.' ============================================================ BERT: CLS attention max: 0.091 Self-attention mean: 0.118 Content/Function ratio: 2.11 Attention entropy: 43.097 DistilBERT: CLS attention max: 0.189 Self-attention mean: 0.093 Content/Function ratio: 1.71 Attention entropy: 42.230 RoBERTa: CLS attention max: 0.285 Self-attention mean: 0.132 Content/Function ratio: 50.00 Attention entropy: 39.187
def visualize_attention_patterns(all_patterns, sentence_domains):
    """Visualize attention patterns across models and sentences.

    Produces four artifacts in order:
      1. a 2x2 bar-chart figure of per-model averages for each metric,
      2. a pivot table of metric means by (domain, model),
      3. a flat per-sentence/per-model metrics table,
      4. a 2x2 grouped bar-chart figure of metric means by domain.

    NOTE(review): reads the module-level ``models_data`` and ``MODELS``
    globals rather than taking them as parameters — keep those in sync
    with the keys present in ``all_patterns``.

    Returns (overall-averages figure, per-sentence DataFrame, domain pivot table).
    """
    metrics = ["cls_attention_max", "self_attention_mean", "content_vs_function", "attention_entropy"]
    metric_names = ["CLS Attention Max", "Self-Attention Mean", "Content/Function Ratio", "Entropy"]
    models = list(models_data.keys())
    sentence_keys = list(all_patterns.keys())
    # 1. GRAPHICS BY METRIC (overall averages per model)
    print("\n OVERALL AVERAGE PATTERNS:")
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=metric_names,
        specs=[[{"type": "bar"}, {"type": "bar"}],
               [{"type": "bar"}, {"type": "bar"}]]
    )
    # (row, col) slot for each of the four metrics
    positions = [(1,1), (1,2), (2,1), (2,2)]
    for idx, (metric, metric_name) in enumerate(zip(metrics, metric_names)):
        row, col = positions[idx]
        # Mean of this metric across all sentences, per model.
        model_means = {}
        for model in models:
            values = [all_patterns[sentence_key][model][metric] for sentence_key in sentence_keys]
            model_means[model] = np.mean(values)
        colors = [MODELS[model]["color"] for model in models]
        fig.add_trace(
            go.Bar(
                x=list(model_means.keys()),
                y=list(model_means.values()),
                marker_color=colors,
                name=metric_name,
                showlegend=False
            ),
            row=row, col=col
        )
    fig.update_layout(
        title_text="Attention Patterns per Model (Overall Averages)",
        title_x=0.5,
        height=600
    )
    fig.show()
    # 2. TABLE BY DOMAIN — long-format records, then pivot to (domain, model) means.
    print("\n TABLE BY DOMAIN:")
    domain_data = []
    for sentence_key, domain in zip(sentence_keys, sentence_domains):
        for model in models:
            for metric in metrics:
                domain_data.append({
                    "Domain": domain,
                    "Model": model,
                    "Metric": metric,
                    "Value": all_patterns[sentence_key][model][metric]
                })
    domain_df = pd.DataFrame(domain_data)
    pivot_table = domain_df.pivot_table(
        values='Value',
        index=['Domain', 'Model'],
        columns='Metric',
        aggfunc='mean'
    ).round(3)
    print("\nAverage metrics by domain and model:")
    display(pivot_table)
    # 3. TABLE BY SENTENCES — one row per (sentence, model), metrics as strings.
    print("\n TABLE BY SENTENCES:")
    pattern_data = []
    for i, (sentence_key, domain) in enumerate(zip(sentence_keys, sentence_domains)):
        for model in models:
            row = {
                "Sentence": f"S{i+1}",
                "Domain": domain,
                "Model": model
            }
            for metric in metrics:
                row[metric] = f"{all_patterns[sentence_key][model][metric]:.3f}"
            pattern_data.append(row)
    pattern_df = pd.DataFrame(pattern_data)
    display(pattern_df)
    # 4. COMPARATIVE CHART BY DOMAIN — grouped bars: one group per domain, one bar per model.
    print("\n DOMAIN COMPARISON:")
    fig_domain = make_subplots(
        rows=2, cols=2,
        subplot_titles=metric_names,
        specs=[[{"type": "bar"}, {"type": "bar"}],
               [{"type": "bar"}, {"type": "bar"}]]
    )
    unique_domains = sorted(set(sentence_domains))
    for idx, (metric, metric_name) in enumerate(zip(metrics, metric_names)):
        row, col = positions[idx]
        for model in models:
            domain_means = []
            for domain in unique_domains:
                # Filter data for this domain
                domain_values = [
                    all_patterns[sentence_key][model][metric]
                    for sentence_key, d in zip(sentence_keys, sentence_domains)
                    if d == domain
                ]
                domain_means.append(np.mean(domain_values))
            fig_domain.add_trace(
                go.Bar(
                    x=unique_domains,
                    y=domain_means,
                    name=model,
                    marker_color=MODELS[model]["color"],
                    showlegend=(idx == 0) # Show legend only in first subplot
                ),
                row=row, col=col
            )
    fig_domain.update_layout(
        title_text="Attention Patterns by Domain",
        title_x=0.5,
        height=700,
        barmode='group'
    )
    fig_domain.show()
    return fig, pattern_df, pivot_table
# Generate all pattern figures and tables from the collected metrics.
patterns_fig, patterns_df, domain_pivot = visualize_attention_patterns(all_patterns, SENTENCE_DOMAINS)
OVERALL AVERAGE PATTERNS:
TABLE BY DOMAIN: Average metrics by domain and model:
| Metric | attention_entropy | cls_attention_max | content_vs_function | self_attention_mean | |
|---|---|---|---|---|---|
| Domain | Model | ||||
| Business | BERT | 55.365 | 0.111 | 3.344 | 0.105 |
| DistilBERT | 58.253 | 0.197 | 3.143 | 0.083 | |
| RoBERTa | 53.341 | 0.267 | 31.708 | 0.133 | |
| Sci/Tech | BERT | 50.914 | 0.123 | 7.479 | 0.112 |
| DistilBERT | 52.223 | 0.211 | 7.841 | 0.089 | |
| RoBERTa | 51.451 | 0.289 | 38.574 | 0.129 | |
| Sports | BERT | 61.612 | 0.106 | 2.803 | 0.101 |
| DistilBERT | 65.759 | 0.173 | 2.013 | 0.083 | |
| RoBERTa | 59.828 | 0.269 | 27.725 | 0.130 | |
| World | BERT | 55.811 | 0.127 | 5.519 | 0.104 |
| DistilBERT | 57.042 | 0.182 | 5.542 | 0.082 | |
| RoBERTa | 54.705 | 0.274 | 33.588 | 0.128 |
TABLE BY SENTENCES:
| Sentence | Domain | Model | cls_attention_max | self_attention_mean | content_vs_function | attention_entropy | |
|---|---|---|---|---|---|---|---|
| 0 | S1 | World | BERT | 0.108 | 0.088 | 1.349 | 61.730 |
| 1 | S1 | World | DistilBERT | 0.171 | 0.070 | 0.888 | 62.514 |
| 2 | S1 | World | RoBERTa | 0.281 | 0.128 | 3.808 | 60.199 |
| 3 | S2 | World | BERT | 0.102 | 0.112 | 1.587 | 53.609 |
| 4 | S2 | World | DistilBERT | 0.196 | 0.083 | 2.178 | 50.663 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 295 | S99 | Sci/Tech | DistilBERT | 0.219 | 0.072 | 1.799 | 58.843 |
| 296 | S99 | Sci/Tech | RoBERTa | 0.287 | 0.140 | 38.462 | 54.408 |
| 297 | S100 | Sci/Tech | BERT | 0.091 | 0.118 | 2.107 | 43.097 |
| 298 | S100 | Sci/Tech | DistilBERT | 0.189 | 0.093 | 1.711 | 42.230 |
| 299 | S100 | Sci/Tech | RoBERTa | 0.285 | 0.132 | 50.000 | 39.187 |
300 rows × 7 columns
DOMAIN COMPARISON:
import time
import psutil
import gc
def benchmark_model_performance(models_data, test_sentences):
    """Benchmark inference speed and memory usage of each model.

    For every model, the full tokenize + forward pass is timed on each test
    sentence; memory is estimated as the process RSS delta around the pass
    (noisy, hence the gc.collect() before each measurement).

    Returns
    -------
    pd.DataFrame
        One row per model with formatted timing/memory/parameter columns and
        speed relative to the BERT baseline (falls back to the first model
        if no "BERT" entry exists).
    """
    print("BENCHMARK PERFORMANCE MODELS")
    print("=" * 50)
    results = []
    avg_times = {}  # model name -> mean inference time in seconds (unrounded)
    for model_name, data in models_data.items():
        print(f"\nTest {model_name}...")
        tokenizer = data["tokenizer"]
        model = data["model"]
        times = []
        memory_usage = []
        for sentence in test_sentences:
            gc.collect()  # reduce GC noise in the RSS measurement
            memory_before = psutil.Process().memory_info().rss / 1024 / 1024  # MB
            start_time = time.time()
            with torch.no_grad():
                inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
                outputs = model(**inputs)
            end_time = time.time()
            times.append(end_time - start_time)
            memory_after = psutil.Process().memory_info().rss / 1024 / 1024  # MB
            memory_usage.append(memory_after - memory_before)
        avg_time = np.mean(times)
        std_time = np.std(times)
        avg_memory = np.mean(memory_usage)
        total_params = sum(p.numel() for p in model.parameters())
        avg_times[model_name] = avg_time
        results.append({
            "Model": model_name,
            "Avg Time (ms)": f"{avg_time*1000:.2f}",
            # Fixed: header was French ("Écart-type (ms)") in an English file.
            "Std Time (ms)": f"{std_time*1000:.2f}",
            "Memory (MB)": f"{avg_memory:.1f}",
            "Parameters": f"{total_params:,}",
            "Speed Relative": None,  # filled in below once the baseline is known
        })
        print(f" Avg Time: {avg_time*1000:.2f} ms")
        print(f" Memory: {avg_memory:.1f} MB")
        print(f" Parameters: {total_params:,}")
    # Fixed: the baseline was results[0], silently ASSUMED to be BERT, and was
    # re-parsed from a rounded string. Use the raw times and look BERT up by name.
    baseline = avg_times.get("BERT", next(iter(avg_times.values())))
    for result in results:
        result["Speed Relative"] = f"{baseline / avg_times[result['Model']]:.2f}x"
    return pd.DataFrame(results)
# Benchmark every model on the full sentence set and show the results.
performance_df = benchmark_model_performance(models_data, TEST_SENTENCES)
print("\n BENCHMARK RESULTS:")
display(performance_df)
BENCHMARK PERFORMANCE MODELS ================================================== Test BERT... Avg Time: 44.25 ms Memory: 0.5 MB Parameters: 109,482,240 Test DistilBERT... Avg Time: 22.20 ms Memory: 0.0 MB Parameters: 66,362,880 Test RoBERTa... Avg Time: 44.05 ms Memory: 0.4 MB Parameters: 124,645,632 BENCHMARK RESULTS:
| Model | Avg Time (ms) | Écart-type (ms) | Memory (MB) | Parameters | Speed Relative | |
|---|---|---|---|---|---|---|
| 0 | BERT | 44.25 | 3.65 | 0.5 | 109,482,240 | 1.00x |
| 1 | DistilBERT | 22.20 | 0.98 | 0.0 | 66,362,880 | 1.99x |
| 2 | RoBERTa | 44.05 | 2.67 | 0.4 | 124,645,632 | 1.00x |
def create_performance_charts(performance_df):
    """Plot inference time, memory, parameter count and the speed/size trade-off.

    Expects the formatted string columns produced by
    benchmark_model_performance; numeric values are parsed back out of them.
    """
    models = performance_df["Model"].values
    times = [float(v.replace(' ms', '')) for v in performance_df["Avg Time (ms)"].values]
    memory = [float(v.replace(' MB', '')) for v in performance_df["Memory (MB)"].values]
    params = [int(v.replace(',', '')) for v in performance_df["Parameters"].values]
    colors = [MODELS[m]["color"] for m in models]
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=("Inference Time", "Memory Usage",
                        "Parameter Count", "Speed vs Parameters"),
        specs=[[{"type": "bar"}, {"type": "bar"}],
               [{"type": "bar"}, {"type": "scatter"}]]
    )
    # The three bar panels differ only in data and slot; build them in a loop.
    bar_panels = [
        (times, "Time", 1, 1),
        (memory, "Memory", 1, 2),
        (params, "Parameters", 2, 1),
    ]
    for y_values, trace_name, r, c in bar_panels:
        fig.add_trace(
            go.Bar(x=models, y=y_values, marker_color=colors,
                   name=trace_name, showlegend=False),
            row=r, col=c
        )
    # Scatter panel: speed vs model size trade-off.
    fig.add_trace(
        go.Scatter(
            x=times, y=params,
            mode='markers+text',
            marker=dict(size=15, color=colors),
            text=models,
            textposition="top center",
            name="Trade-off",
            showlegend=False
        ),
        row=2, col=2
    )
    fig.update_layout(
        title_text="Transformer Models Performance",
        title_x=0.5,
        height=800,
        showlegend=False
    )
    for c in (1, 2):
        fig.update_xaxes(title_text="Models", row=1, col=c)
    fig.update_xaxes(title_text="Models", row=2, col=1)
    fig.update_xaxes(title_text="Time (ms)", row=2, col=2)
    fig.update_yaxes(title_text="Time (ms)", row=1, col=1)
    fig.update_yaxes(title_text="Memory (MB)", row=1, col=2)
    fig.update_yaxes(title_text="Parameters", row=2, col=1)
    fig.update_yaxes(title_text="Parameters", row=2, col=2)
    fig.show()
    return fig
# Draw the performance comparison charts.
performance_fig = create_performance_charts(performance_df)
Summary¶
We compared attention mechanisms across three Transformer architectures and found distinct patterns for each.
1. RoBERTa
2. DistilBERT
3. BERT
When to Use Each Model¶
- High performance tasks → RoBERTa
- Resource constraints → DistilBERT
- General purpose/exploration → BERT
Model Signatures¶
| Model | CLS Agg. | Self-Att. | Ratio C/F | Entropy | Speed |
|---|---|---|---|---|---|
| BERT | 1.00x | ||||
| DistilBERT | |||||
| RoBERTa |
import os
os.makedirs("results", exist_ok=True)
# 1. Sentences from the dataset
sentences_df = pd.DataFrame({
    "Sentence_ID": [f"S{i+1}" for i in range(len(TEST_SENTENCES))],
    "Domain": SENTENCE_DOMAINS,
    "Sentence": TEST_SENTENCES,
    "Word_Count": [len(s.split()) for s in TEST_SENTENCES]
})
sentences_df.to_csv("results/dataset_sentences.csv", index=False)
# 2. Metrics by domain
domain_pivot.to_csv("results/metrics_by_domain.csv")
# 3. Metrics by sentence — flatten the nested pattern dict into flat records.
metric_columns = {
    "CLS_Attention_Max": "cls_attention_max",
    "CLS_Attention_Mean": "cls_attention_mean",
    "Self_Attention_Mean": "self_attention_mean",
    "Content_Function_Ratio": "content_vs_function",
    "Attention_Entropy": "attention_entropy",
}
sentence_metrics = []
for idx, (sentence_key, domain) in enumerate(zip(all_patterns.keys(), SENTENCE_DOMAINS)):
    for model in models_data.keys():
        record = {"Sentence_ID": f"S{idx+1}", "Domain": domain, "Model": model}
        for column, metric_key in metric_columns.items():
            record[column] = all_patterns[sentence_key][model][metric_key]
        sentence_metrics.append(record)
sentence_metrics_df = pd.DataFrame(sentence_metrics)
sentence_metrics_df.to_csv("results/metrics_by_sentence.csv", index=False)